Compare binding to different receptor ligands¶

In [1]:
import itertools

import altair as alt

import pandas as pd

# Turn off Altair's maximum-rows check so that the large data frames built
# below can be embedded in charts; the result is bound to `_` so the cell
# produces no output.
_ = alt.data_transformers.disable_max_rows()
In [2]:
# this cell is tagged parameters for `papermill` parameterization
#
# Every name below is a placeholder (`None`) meant to be overridden with a real
# value when the notebook is parameterized.

# --- inputs ---
# CSV of mutation effects on cell entry; this is the name the data-reading
# code actually uses, so it needs a placeholder here too
entry_csv = None
# NOTE(review): nothing visible in this notebook reads this name (entry data
# are read from `entry_csv`); kept only for backward compatibility -- confirm
# and remove
entry_293T_human_Mxra8 = None
# CSVs of mutation effects on binding to each ligand
binding_human_Mxra8 = None
binding_mouse_Mxra8 = None
# CSV with per-site annotations, keyed by `sequential_site`
addtl_site_annotations = None
# CSV mapping `reference_site` to `sequential_site` and `region`
site_numbering_map = None

# --- outputs (HTML files the charts are saved to) ---
mut_corr_chart_html = None
site_corr_chart_html = None
site_chart_html = None
In [3]:
# Parameters
# Concrete values that override the `None` placeholders of the tagged
# parameters cell.
# NOTE(review): this looks like the cell `papermill` injects at run time, in
# which case it should not be edited by hand -- confirm.

# input CSVs
entry_csv = "results/func_effects/averages/293T-Mxra8_entry_func_effects.csv"
binding_human_Mxra8 = "results/receptor_affinity/averages/human_Mxra8_mut_effect.csv"
binding_mouse_Mxra8 = "results/receptor_affinity/averages/mouse_Mxra8_mut_effect.csv"
addtl_site_annotations = "data/addtl_site_annotations.csv"
site_numbering_map = "data/site_numbering_map.csv"
# output HTML charts
mut_corr_chart_html = "results/compare_human_mouse_mxra8_mut_binding_corr.html"
site_corr_chart_html = "results/compare_human_mouse_mxra8_site_binding_corr.html"
site_chart_html = "results/compare_human_mouse_mxra8_site_chart.html"
In [4]:
# Additional hardcoded parameters

# --- cell-entry settings ---
# label used for cell entry in the slider and tooltips
entry_name = "entry in 293T-Mxra8 cells"
# lower bound applied to `times_seen` when reading the entry and binding CSVs
min_times_seen = 2
# initial value of the entry slider and entry cutoff for the site summaries
min_entry = -4
# NB: despite the "min" in its name this is applied as an *upper* bound on
# `effect_std` when the entry CSV is read
min_entry_std = 2.25

# --- ligand settings (every dict is keyed by ligand identifier) ---
# identifier -> display name; the insertion order is the order in which
# ligands are read and paired for the correlation charts
ligands = {
    "mouse_Mxra8": "mouse Mxra8",
    "human_Mxra8": "human Mxra8",
}
# the charts are written to one fixed file per chart type, so more than one
# pair of ligands would overwrite earlier output
assert len(ligands) == 2, "saving for corr charts only works for 2 ligands currently"

# identifier -> CSV with mutation effects on binding to that ligand
binding_csvs = {
    "human_Mxra8": binding_human_Mxra8,
    "mouse_Mxra8": binding_mouse_Mxra8,
}
# identifier -> prefix of the `... binding_median` / `... binding_std` columns
binding_csv_col_names = {
    ligand_key: "Mxra8" for ligand_key in ("human_Mxra8", "mouse_Mxra8")
}
# identifier -> upper bound on the binding standard deviation column
max_binding_stds = {
    "human_Mxra8": 2.5,
    "mouse_Mxra8": 2.25,
}

# --- site annotation settings ---
# column in the annotations CSV -> name it gets in the merged data
addtl_site_annotations_cols = {"domain": "domain", "contacts": "Mxra8 contact"}

Read the data¶

In [5]:
# read the data

# Cell-entry effects: keep only mutations that pass the `times_seen` and
# `effect_std` filters, and rename the `effect` column to `entry`.
print(f"Reading cell entry from {entry_csv=}")
data_df = (
    pd.read_csv(entry_csv)
    .query("times_seen >= @min_times_seen")
    # NB: `min_entry_std` acts as an upper bound here despite its name
    .query("effect_std <= @min_entry_std")
    # No `mutation` column is built at this point: the column selection below
    # would discard it immediately, and it is created after all merges.
    [["site", "wildtype", "mutant", "effect"]]
    .rename(columns={"effect": "entry"})
)

# Attach, for every ligand, its binding value and a text label of the form
# "<median> (<value>, <value>, ...)" to the entry data.
for ligand in ligands:
    print(f"Reading binding to {ligand=} from {binding_csvs[ligand]=}")
    col_name = binding_csv_col_names[ligand]
    max_std = max_binding_stds[ligand]

    # filter the raw binding table, then expose the median under the ligand name
    bind_df = pd.read_csv(binding_csvs[ligand])
    bind_df = bind_df.query("times_seen >= @min_times_seen")
    bind_df = bind_df.query("frac_models == 1")
    bind_df = bind_df.query(f"`{col_name} binding_std` <= @max_std")
    bind_df = bind_df.rename(columns={f"{col_name} binding_median": ligand})

    # NOTE(review): columns are picked by position (12th onward); presumably
    # these are the per-replicate values -- confirm against the CSV layout
    bind_rep_cols = bind_df.columns[11:].tolist()

    def _binding_label(r):
        """Format one row as the median followed by the values in `bind_rep_cols`."""
        return f"{r[ligand]:.2f} ({', '.join(str(round(r[c], 2)) for c in bind_rep_cols)})"

    label_col = f"{ligand}_label"
    keep_cols = ["site", "wildtype", "mutant", ligand, label_col]
    bind_df = (
        bind_df
        .assign(label=lambda x: x.apply(_binding_label, axis=1))
        .rename(columns={"label": label_col})
        [keep_cols]
    )

    # left merge keeps mutations without binding data (they get NaN)
    data_df = pd.merge(
        data_df,
        bind_df,
        how="left",
        on=["site", "mutant", "wildtype"],
        validate="1:1",
    )

# Map each reference site to its sequential site number and region. No `how`
# is given, so this is pandas' default inner merge.
print(f"Adding sequential site from {site_numbering_map=}")
site_numbering_df = pd.read_csv(site_numbering_map).rename(
    columns={"reference_site": "site"}
)
data_df = pd.merge(
    data_df,
    site_numbering_df[["site", "sequential_site", "region"]],
    on="site",
    validate="many_to_one",
)

# Add the additional per-site annotations under their display names; sites
# without an annotation are kept (left merge) and get NaN.
print(f"Adding site annotations from {addtl_site_annotations=}")
site_annotations_df = pd.read_csv(addtl_site_annotations)
site_annotations_df = site_annotations_df[
    ["sequential_site", *addtl_site_annotations_cols]
].rename(columns=addtl_site_annotations_cols)
data_df = pd.merge(
    data_df,
    site_annotations_df,
    how="left",
    on="sequential_site",
    validate="many_to_one",
)

# Final tidy-up: drop rows where the mutant equals the wildtype, add a
# mutation name, mark unannotated sites as non-contacts, and order the rows
# by sequential site and mutant.
new_cols = {
    "mutation": lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
    "Mxra8 contact": lambda x: x["Mxra8 contact"].fillna("no"),
}

data_df = (
    data_df
    .loc[lambda d: d["wildtype"] != d["mutant"]]
    .assign(**new_cols)
    .sort_values(by=["sequential_site", "mutant"])
    .reset_index(drop=True)
)

data_df
Reading cell entry from entry_csv='results/func_effects/averages/293T-Mxra8_entry_func_effects.csv'
Reading binding to ligand='mouse_Mxra8' from binding_csvs[ligand]='results/receptor_affinity/averages/mouse_Mxra8_mut_effect.csv'
Reading binding to ligand='human_Mxra8' from binding_csvs[ligand]='results/receptor_affinity/averages/human_Mxra8_mut_effect.csv'
Adding sequential site from site_numbering_map='data/site_numbering_map.csv'
Adding site annotations from addtl_site_annotations='data/addtl_site_annotations.csv'
Out[5]:
site wildtype mutant entry mouse_Mxra8 mouse_Mxra8_label human_Mxra8 human_Mxra8_label sequential_site region domain Mxra8 contact mutation
0 -1(E3) M I -7.5410 NaN NaN NaN NaN 1 E3 NaN no M-1(E3)I
1 -1(E3) M T -7.5630 NaN NaN NaN NaN 1 E3 NaN no M-1(E3)T
2 1(E3) S A -1.0250 -0.11910 -0.12 (-0.06, -0.18) 0.04762 0.05 (0.06, 0.03) 2 E3 E3 no S1(E3)A
3 1(E3) S C -0.7132 -0.21170 -0.21 (-0.44, 0.01) -0.73310 -0.73 (-0.61, -0.85) 2 E3 E3 no S1(E3)C
4 1(E3) S D 0.1852 0.02613 0.03 (0.02, 0.04) -0.21540 -0.22 (-0.21, -0.22) 2 E3 E3 no S1(E3)D
... ... ... ... ... ... ... ... ... ... ... ... ... ...
18957 439(E1) H V -0.4753 NaN NaN NaN NaN 988 E1 E1-cytoplasmic no H439(E1)V
18958 439(E1) H W -0.2051 0.23070 0.23 (-0.03, 0.49) -0.28620 -0.29 (-0.64, 0.07) 988 E1 E1-cytoplasmic no H439(E1)W
18959 439(E1) H Y -0.2293 -0.01344 -0.01 (-0.12, 0.1) -0.24560 -0.25 (-0.29, -0.2) 988 E1 E1-cytoplasmic no H439(E1)Y
18960 440(E1) * Q -3.3990 0.13000 0.13 (-0.02, 0.28) -1.51300 -1.51 (-2.55, -0.48) 989 E1 NaN no *440(E1)Q
18961 440(E1) * Y -1.0960 0.64660 0.65 (1.12, 0.17) 0.59920 0.60 (1.22, -0.02) 989 E1 NaN no *440(E1)Y

18962 rows × 13 columns

Simple correlation of binding to different ligands across all mutations¶

In [6]:
# plot the data

# Hover selections: one keyed on the site, one keyed on the single mutation
# under the pointer. `empty=False` means nothing is selected until a point
# is hovered.
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)
mut_selection = alt.selection_point(fields=["mutation"], on="mouseover", empty=False)

# Slider for the minimum entry a mutation needs in order to be shown; it
# ranges from the lowest observed entry up to 0 and starts at `min_entry`.
entry_slider_binding = alt.binding_range(
    name=f"minimum {entry_name}",
    min=data_df["entry"].min(),
    max=0,
)
min_entry_slider = alt.param(
    name="min_entry_slider",
    value=min_entry,
    bind=entry_slider_binding,
)

# Only hand the chart the columns it needs.
mut_corr_cols = [
    "mutation",
    "entry",
    "site",
    *ligands,
    *(f"{lig}_label" for lig in ligands),
]
mut_corr_base = alt.Chart(data_df[mut_corr_cols])

# One scatter plot per pair of ligands: each point is a mutation, positioned
# by its binding to the two ligands.
for ligand1, ligand2 in itertools.combinations(ligands, 2):

    # positional channels
    x_binding = alt.X(
        ligand1,
        title=f"binding to {ligands[ligand1]}",
        scale=alt.Scale(nice=False, padding=5),
    )
    y_binding = alt.Y(
        ligand2,
        title=f"binding to {ligands[ligand2]}",
        scale=alt.Scale(nice=False, padding=5),
    )

    # points at the hovered site are emphasized; the hovered mutation itself
    # additionally gets a thicker outline
    highlight_channels = {
        "color": alt.condition(site_selection, alt.value("red"), alt.value("gray")),
        "opacity": alt.condition(site_selection, alt.value(0.9), alt.value(0.15)),
        "size": alt.condition(site_selection, alt.value(55), alt.value(40)),
        "strokeWidth": alt.condition(mut_selection, alt.value(3), alt.value(0.6)),
    }

    mut_tooltips = [
        "mutation",
        alt.Tooltip("entry", format=".2f", title=entry_name),
        alt.Tooltip(f"{ligand1}_label", title=ligands[ligand1]),
        alt.Tooltip(f"{ligand2}_label", title=ligands[ligand2]),
    ]

    mut_corr_chart = (
        mut_corr_base
        .add_params(site_selection, mut_selection, min_entry_slider)
        # hide mutations whose entry is below the slider value
        .transform_filter(alt.datum["entry"] >= min_entry_slider)
        .encode(x_binding, y_binding, tooltip=mut_tooltips, **highlight_channels)
        .mark_circle(stroke="black")
        .properties(width=175, height=175)
        .configure_axis(grid=False)
    )

    display(mut_corr_chart)

    # every pair is written to the same file (exactly one pair is expected)
    print(f"Saving to {mut_corr_chart_html}")
    mut_corr_chart.save(mut_corr_chart_html)
Saving to results/compare_human_mouse_mxra8_mut_binding_corr.html

Plot site effects on binding¶

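We pre-filter mutations on the cell-entry cutoff, then sum the positive and the negative binding effects at each site. For each ligand, both sums are divided by a single scale factor: the larger of the maximum summed positive effect and the maximum magnitude of the summed negative effect across all sites. The summed absolute effect is scaled separately by its own per-ligand maximum: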

In [7]:
# Only mutations at or above the entry cutoff contribute to the site summaries.
data_filtered_df = data_df.query("entry >= @min_entry")

site_keys = ["site", "sequential_site", "wildtype", "region", "Mxra8 contact"]

# long format: one row per (mutation, ligand) holding that binding effect
ligand_effects_long = data_filtered_df.melt(
    id_vars=site_keys,
    value_vars=ligands,
    var_name="ligand",
    value_name="effect",
)

# per (ligand, site): summed positive, summed negative and summed absolute
# effects; `dropna=False` keeps groups whose key columns contain NaN
site_sums = (
    ligand_effects_long
    .groupby(["ligand", *site_keys], as_index=False, dropna=False)
    .aggregate(
        positive_effect=pd.NamedAgg("effect", lambda s: s.clip(lower=0).sum()),
        negative_effect=pd.NamedAgg("effect", lambda s: s.clip(upper=0).sum()),
        absolute_effect=pd.NamedAgg("effect", lambda s: s.abs().sum()),
    )
)

# scale by min / max
# Positive and negative sums share one factor per ligand (the larger of the
# biggest positive sum and the biggest negative magnitude), so the two remain
# comparable; absolute sums are scaled by their own per-ligand maximum.
by_ligand = site_sums.groupby("ligand")
pos_neg_scale = pd.concat(
    [
        by_ligand["positive_effect"].transform("max"),
        -by_ligand["negative_effect"].transform("min"),
    ],
    axis=1,
).max(axis=1)
abs_scale = by_ligand["absolute_effect"].transform("max")

site_df = site_sums.assign(
    positive_effect=site_sums["positive_effect"] / pos_neg_scale,
    negative_effect=site_sums["negative_effect"] / pos_neg_scale,
    absolute_effect=site_sums["absolute_effect"] / abs_scale,
)
In [8]:
chart_width = 600

# Sites to label on the x-axis: a sparse, regularly spaced subset taken in
# sequential-site order. `site_df` has one row per (ligand, site), so the
# slice steps over rows rather than over unique sites.
# The result is converted to a plain list because `Axis(values=...)` ends up
# in the chart specification, which expects an array rather than a pandas
# Series.
site_axis_ticks = (
    site_df[["sequential_site", "site"]]
    .sort_values("sequential_site")
    ["site"]
    .iloc[50::130]
    .tolist()
)

# Bar chart of site-level effects, one row of bars per ligand: each bar spans
# from the (scaled) summed negative effect up to the summed positive effect
# and is colored by the site's Mxra8 contact annotation.
site_binding_chart = (
    alt.Chart(
        site_df.assign(ligand_name=lambda x: "binding to " + x["ligand"].map(ligands))
    )
    .encode(
        alt.X(
            "site",
            # order sites along the axis by their sequential number
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(
                values=site_axis_ticks,
                labelAngle=0,
                titleFontSize=11,
            ),
        ),
        alt.Y("positive_effect", title=None, scale=alt.Scale(nice=False, padding=4)),
        alt.Y2("negative_effect", title=None),
        alt.Color(
            "Mxra8 contact",
            scale=alt.Scale(
                domain=["no", "wrapped", "intraspike", "interspike"],
                range=["gray", "red", "purple", "orange"],
            ),
        ),
        alt.Row(
            "ligand_name",
            title=None,
            header=alt.Header(labelFontSize=11, labelFontStyle="bold", labelPadding=2),
            spacing=5,
        ),
        tooltip=[
            "site",
            "wildtype",
            alt.Tooltip("positive_effect", format=".2f"),
            alt.Tooltip("negative_effect", format=".2f"),
            "Mxra8 contact",
        ],
    )
    .mark_bar(opacity=1, width=1)
    .properties(width=chart_width, height=128)
)

Make overlay bar with regions:

In [9]:
# Background strip: one rectangle per sequential site, colored by region.
region_df = site_df[["sequential_site", "region"]].drop_duplicates()

region_chart = (
    alt.Chart(region_df)
    .mark_rect(opacity=0.75, strokeWidth=0)
    .encode(
        x=alt.X("sequential_site:O", axis=None),
        color=alt.Color(
            "region",
            legend=None,
            scale=alt.Scale(range=["AliceBlue", "CadetBlue", "CadetBlue", "AliceBlue"]),
        ),
    )
    .properties(width=chart_width)
)

# Region names, each placed at the mean sequential site of its region.
text_df = site_df.groupby("region", as_index=False).agg(
    x=pd.NamedAgg(column="sequential_site", aggfunc="mean")
)

# the text layer uses a quantitative x spanning the full range of sequential sites
sequential_site_range = [
    site_df["sequential_site"].min(),
    site_df["sequential_site"].max(),
]

text_chart = (
    alt.Chart(text_df)
    .mark_text(fontWeight="bold", fontSize=11)
    .encode(
        x=alt.X(
            "x:Q",
            title=None,
            scale=alt.Scale(domain=sequential_site_range),
            axis=None,
        ),
        text=alt.Text("region"),
    )
    .properties(width=chart_width, height=13)
)

# layer the labels over the colored strip
overlay_chart = alt.layer(region_chart, text_chart)

Combine overlay and site chart:

In [10]:
# Stack the region overlay on top of the per-ligand site chart. Color scales
# are resolved independently because the overlay colors by region while the
# site chart colors by Mxra8 contact. Only the x-axis is made interactive.
site_chart = (
    alt.vconcat(overlay_chart, site_binding_chart, spacing=1)
    .resolve_scale(color="independent")
    .configure_axis(grid=False)
    .configure_view(stroke="black", strokeOpacity=1, strokeWidth=1)
    .interactive(bind_x=True, bind_y=False)
)

# Explicit `display` is needed: a bare `site_chart` expression is only
# rendered when it is the last statement of the cell.
display(site_chart)

# Save the site chart itself (not the mutation-level correlation chart built
# earlier) to the site-chart output file.
print(f"Saving to {site_chart_html}")
site_chart.save(site_chart_html)
Saving to results/compare_human_mouse_mxra8_site_chart.html

Plot correlations in site effects¶

In [11]:
# Reshape the site summaries so each row is one (site, metric) pair with one
# column per ligand, which is the layout needed for ligand-vs-ligand plots.
site_id_cols = ["site", "wildtype", "region", "Mxra8 contact"]

# long format: one row per (ligand, site, metric)
site_metrics_long = site_df.melt(
    id_vars=["ligand", *site_id_cols],
    value_vars=["positive_effect", "negative_effect", "absolute_effect"],
    var_name="metric",
    value_name="effect",
)

# wide format: ligands become columns (default `pivot_table` aggregation)
site_corr_df = pd.pivot_table(
    site_metrics_long,
    index=[*site_id_cols, "metric"],
    columns="ligand",
    values="effect",
).reset_index()
In [12]:
tooltip_cols = ["site", "wildtype", "region", "Mxra8 contact"]

# One row of scatter plots per pair of ligands, with one panel per metric;
# each point is a site and each panel title reports the correlation.
for ligand1, ligand2 in itertools.combinations(ligands, 2):

    # correlation between the two ligands, computed separately for each
    # metric: metric -> r
    corr_by_metric = site_corr_df.groupby("metric")[[ligand1, ligand2]].corr()
    corrs = (
        corr_by_metric
        .reset_index(level=1)
        .query("ligand == @ligand1")
        [ligand2]
        .to_dict()
    )

    # panel titles: readable metric name plus its correlation
    metric_titles = {
        metric: f"{metric.replace('_', ' ')} at site (r = {corrs[metric]:.2f})"
        for metric in site_corr_df["metric"].unique()
    }
    site_corr_plot_df = site_corr_df[tooltip_cols + [ligand1, ligand2, "metric"]].assign(
        metric=lambda x: x["metric"].map(metric_titles)
    )

    # the hovered site is emphasized in every panel
    site_highlight = {
        "color": alt.condition(site_selection, alt.value("red"), alt.value("gray")),
        "strokeWidth": alt.condition(site_selection, alt.value(3), alt.value(1)),
        "size": alt.condition(site_selection, alt.value(60), alt.value(35)),
        "opacity": alt.condition(site_selection, alt.value(1), alt.value(0.25)),
    }

    site_tooltips = [
        *tooltip_cols,
        alt.Tooltip(ligand1, title=ligands[ligand1], format=".2f"),
        alt.Tooltip(ligand2, title=ligands[ligand2], format=".2f"),
    ]

    site_corr_chart = (
        alt.Chart(site_corr_plot_df)
        .add_params(site_selection)
        .encode(
            alt.X(ligand1, title=ligands[ligand1], scale=alt.Scale(nice=False, padding=6)),
            alt.Y(ligand2, title=ligands[ligand2], scale=alt.Scale(nice=False, padding=6)),
            alt.Column(
                "metric",
                title=None,
                header=alt.Header(labelFontStyle="bold", labelFontSize=11, labelPadding=2),
            ),
            tooltip=site_tooltips,
            **site_highlight,
        )
        .mark_circle(stroke="black")
        # each metric panel gets its own axis ranges
        .resolve_scale(x="independent", y="independent")
        .configure_axis(grid=False)
        .properties(width=140, height=140)
    )

    display(site_corr_chart)

    # every pair is written to the same file (exactly one pair is expected)
    print(f"Saving to {site_corr_chart_html}")
    site_corr_chart.save(site_corr_chart_html)
Saving to results/compare_human_mouse_mxra8_site_binding_corr.html
In [ ]: